[SCF][PIPELINE] Handle the case when values from the peeled prologue may escape out of the loop #105755

Merged

pawelszczerbuk
Contributor

Previously, values in the peeled prologue that were not handled by `predicateFn` were passed to the loop body without any other predication. If such a value is later used outside the loop, it may be incorrect when the number of iterations is smaller than the number of stages minus one, because the prologue then computes iterations the original loop would never execute. These values need the same kind of masking that is done in the main loop body, using the already existing predicates.
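
To make the failure mode concrete, here is a minimal scalar sketch of the problem, not the pipeliner itself. All names (`referenceLoop`, `maskedLoad`, `pipelinedLoop`, `numStages`, `maskPrologue`) are hypothetical, and the model assumes the loop-carried update runs in the first stage, so that `numStages - 1` iterations are peeled into the prologue. The kernel is collapsed into a plain loop over the remaining iterations.

```cpp
#include <cassert>
#include <vector>

// Reference semantics: the original, unpipelined loop. The final value of
// `acc` is what escapes the loop as its result.
int referenceLoop(const std::vector<int> &a, int init) {
  int acc = init;
  for (int x : a)
    acc = 2 * acc + x;
  return acc;
}

// Stand-in for a load that was already predicated: yields 0 when masked off.
int maskedLoad(const std::vector<int> &a, int i) {
  return i < static_cast<int>(a.size()) ? a[i] : 0;
}

// Hypothetical model of the pipelined loop. `maskPrologue` toggles the
// select-style masking of the carried value in the peeled prologue.
int pipelinedLoop(const std::vector<int> &a, int init, int numStages,
                  bool maskPrologue) {
  assert(numStages >= 1 && "a pipeline needs at least one stage");
  const int n = static_cast<int>(a.size());
  const int peeled = numStages - 1;
  int acc = init;
  // Prologue: the peeled iterations are emitted whether or not they exist.
  for (int i = 0; i < peeled; ++i) {
    const bool pred = i < n; // predicate: does iteration i exist?
    const int next = 2 * acc + maskedLoad(a, i);
    // With masking: select(pred, next, previous value). Without: always next.
    acc = (!maskPrologue || pred) ? next : acc;
  }
  // Remaining iterations, standing in for the predicated kernel.
  for (int i = peeled; i < n; ++i)
    acc = 2 * acc + a[i];
  return acc;
}
```

In this model, an empty input with three stages and an initial value of 1 makes the unmasked variant return 4 while the reference returns 1. The masked variant agrees with the reference for every trip count.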

@llvmbot
Member

llvmbot commented Aug 22, 2024

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: None (pawelszczerbuk)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/105755.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp (+15-5)
  • (modified) mlir/test/Dialect/SCF/loop-pipelining.mlir (+17-9)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index cc1a22d0d48a18..d8e1cc0ecef88e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -268,7 +268,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
 }
 
 void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
-  // Initialize the iteration argument to the loop initiale values.
+  // Initialize the iteration argument to the loop initial values.
   for (auto [arg, operand] :
        llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
     setValueMapping(arg, operand.get(), 0);
@@ -320,16 +320,26 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
-        setValueMapping(op->getResult(destId), newOp->getResult(destId),
-                        i - stages[op]);
+        Value source = newOp->getResult(destId);
         // If the value is a loop carried dependency update the loop argument
-        // mapping.
         for (OpOperand &operand : yield->getOpOperands()) {
           if (operand.get() != op->getResult(destId))
             continue;
+          if (predicates[predicateIdx] &&
+              !forOp.getResult(operand.getOperandNumber()).use_empty()) {
+            // If the value is used outside the loop, we need to make sure we
+            // return the correct version of it.
+            Value prevValue = valueMapping
+                [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
+                [i - stages[op]];
+            source = rewriter.create<arith::SelectOp>(
+                loc, predicates[predicateIdx], source, prevValue);
+          }
           setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
-                          newOp->getResult(destId), i - stages[op] + 1);
+                          source, i - stages[op] + 1);
         }
+        setValueMapping(op->getResult(destId), newOp->getResult(destId),
+                        i - stages[op]);
       }
     }
   }
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 46e7feca4329ee..9687f80f5ddfc8 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -703,18 +703,26 @@ func.func @distance_1_use(%A: memref<?xf32>, %result: memref<?xf32>) {
 // -----
 
 // NOEPILOGUE-LABEL: stage_0_value_escape(
-func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
+func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub: index) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  %c4 = arith.constant 4 : index
   %cf = arith.constant 1.0 : f32
-// NOEPILOGUE: %[[C3:.+]] = arith.constant 3 : index
-// NOEPILOGUE: %[[A:.+]] = arith.addf
-// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[A]],
-// NOEPILOGUE:   %[[C:.+]] = arith.cmpi slt, %[[IV]], %[[C3]] : index
-// NOEPILOGUE:   %[[S:.+]] = arith.select %[[C]], %{{.+}}, %[[ARG]] : f32
-// NOEPILOGUE:   scf.yield %[[S]]
-  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+// NOEPILOGUE: %[[UB:[^,]+]]: index)
+// NOEPILOGUE-DAG: %[[C0:.+]] = arith.constant 0 : index
+// NOEPILOGUE-DAG: %[[C1:.+]] = arith.constant 1 : index
+// NOEPILOGUE-DAG: %[[CF:.+]] = arith.constant 1.000000e+00
+// NOEPILOGUE: %[[CND0:.+]] = arith.cmpi sgt, %[[UB]], %[[C0]]
+// NOEPILOGUE: scf.if
+// NOEPILOGUE: %[[IF:.+]] = scf.if %[[CND0]]
+// NOEPILOGUE:   %[[A:.+]] = arith.addf
+// NOEPILOGUE:   scf.yield %[[A]]
+// NOEPILOGUE: %[[S0:.+]] = arith.select %[[CND0]], %[[IF]], %[[CF]]
+// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[S0]],
+// NOEPILOGUE:   %[[UB_1:.+]] = arith.subi %[[UB]], %[[C1]] : index
+// NOEPILOGUE:   %[[CND1:.+]] = arith.cmpi slt, %[[IV]], %[[UB_1]] : index
+// NOEPILOGUE:   %[[S1:.+]] = arith.select %[[CND1]], %{{.+}}, %[[ARG]] : f32
+// NOEPILOGUE:   scf.yield %[[S1]]
+  %r = scf.for %i0 = %c0 to %ub step %c1 iter_args(%arg0 = %cf) -> (f32) {
     %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
     %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
     memref.store %A1_elem, %result[%c0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
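
As a reading aid for the updated `NOEPILOGUE` checks, the following is a scalar sketch of what the expected IR appears to compute. It rests on assumptions not visible in the excerpt above: that the loop yields `%A1_elem`, and that the pipelined `scf.for` keeps the original bounds `0` to `%ub`. The function names are hypothetical; the local names mirror the FileCheck captures. The predicated loads and the stage-2 store are left out, since only the escaping value is of interest here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Original loop, assuming it yields the addf result (%A1_elem).
float originalLoop(const std::vector<float> &A, std::int64_t ub, float cf) {
  assert(ub <= static_cast<std::int64_t>(A.size()));
  float arg = cf;
  for (std::int64_t i = 0; i < ub; ++i)
    arg = A[static_cast<std::size_t>(i)] + arg;
  return arg;
}

// Scalar reading of the NOEPILOGUE checks for the escaping value.
float pipelinedNoEpilogue(const std::vector<float> &A, std::int64_t ub,
                          float cf) {
  assert(ub <= static_cast<std::int64_t>(A.size()));
  const bool cnd0 = ub > 0; // arith.cmpi sgt, %UB, %C0
  float s0 = cf;            // arith.select %CND0, %IF, %CF (false case)
  if (cnd0)
    s0 = A[0] + cf;         // scf.if %CND0 { arith.addf; scf.yield }
  float arg = s0;           // iter_args(%ARG = %S0)
  for (std::int64_t iv = 0; iv < ub; ++iv) {
    const bool cnd1 = iv < ub - 1; // arith.cmpi slt, %IV, %UB_1
    if (cnd1)                      // arith.select %CND1, <addf>, %ARG
      arg = A[static_cast<std::size_t>(iv + 1)] + arg;
  }
  return arg; // scf.yield %S1 on the last iteration
}
```

Under these assumptions, `ub = 0` returns `cf` unchanged. That is the case the old checks, with a constant upper bound of 4, could not exercise.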

Contributor

@ThomasRaoux ThomasRaoux left a comment

Great catch, thanks!

@ThomasRaoux ThomasRaoux merged commit 7c90081 into llvm:main Aug 23, 2024
11 checks passed
cjdb pushed a commit to cjdb/llvm-project that referenced this pull request Aug 23, 2024
…may escape out of the loop (llvm#105755)

5chmidti pushed a commit that referenced this pull request Aug 24, 2024
…may escape out of the loop (#105755)

dmpolukhin pushed a commit to dmpolukhin/llvm-project that referenced this pull request Sep 2, 2024
…may escape out of the loop (llvm#105755)
